package com.mycompany.sentimentanalysis;

import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.*;
import java.net.URL;

/**
 * @author Alex Black (original author); adapted by Jennifer Reese
 */
public class DL4JSentimentAnalysisExample {

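    // IMDB Large Movie Review dataset and the local directory it is extracted to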
    public static final String TRAINING_DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz";
    public static final String EXTRACT_DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_w2vSentiment/");
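    // Path to the pre-trained Google News word vectors; point this at your local copy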
    public static final String GNEWS_VECTORS_PATH = "C:/Jenn Personal/Packt Data Science/Chapter 9 Text Analysis/GoogleNews-vectors-negative300.bin/GoogleNews-vectors-negative300.bin";

    public static void main(String[] args) throws Exception {

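        // Download and extract the IMDB review dataset if it is not already on disk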
        getModelData();
        
        System.out.println("Total memory = " + Runtime.getRuntime().totalMemory());

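        // Hyperparameters: mini-batch size, word-vector dimensionality (must match
        // the 300-dimensional Google News vectors), number of training epochs, and
        // the maximum review length in tokens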
        int batchSize = 50;
        int vectorSize = 300;
        int nEpochs = 5;
        int truncateReviewsToLength = 300;

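        // Network configuration: a single GravesLSTM layer reads the sequence of
        // word vectors; an RNN output layer with softmax and multi-class cross
        // entropy (MCXENT) produces the positive/negative prediction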
        MultiLayerConfiguration sentimentNN = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
                .updater(Updater.RMSPROP)
                .regularization(true).l2(1e-5)
                .weightInit(WeightInit.XAVIER)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(1.0)
                .learningRate(0.0018)
                .list()
                .layer(0, new GravesLSTM.Builder().nIn(vectorSize).nOut(200)
                        .activation("softsign").build())
                .layer(1, new RnnOutputLayer.Builder().activation("softmax")
                        .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(200).nOut(2).build())
                .pretrain(false).backprop(true).build();

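        // Build the network and report the loss score after every iteration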
        MultiLayerNetwork net = new MultiLayerNetwork(sentimentNN);
        net.init();
        net.setListeners(new ScoreIterationListener(1));

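        // Load the pre-trained Google News vectors (binary format) and wrap the
        // review iterators so mini-batches are prefetched asynchronously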
        WordVectors wordVectors = WordVectorSerializer.loadGoogleModel(new File(GNEWS_VECTORS_PATH), true, false);
        DataSetIterator trainData = new AsyncDataSetIterator(new SentimentExampleIterator(EXTRACT_DATA_PATH, wordVectors, batchSize, truncateReviewsToLength, true), 1);
        DataSetIterator testData = new AsyncDataSetIterator(new SentimentExampleIterator(EXTRACT_DATA_PATH, wordVectors, 100, truncateReviewsToLength, false), 1);

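        // Train one epoch at a time, then gauge performance on the test set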
        for (int i = 0; i < nEpochs; i++) {
            net.fit(trainData);
            trainData.reset();

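            // Evaluate the model on the held-out test reviews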
            Evaluation evaluation = new Evaluation();
            while (testData.hasNext()) {
                DataSet t = testData.next();
                INDArray dataFeatures = t.getFeatureMatrix();
                INDArray dataLabels = t.getLabels();
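                // Mask arrays flag the padded time steps of variable-length reviews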
                INDArray inMask = t.getFeaturesMaskArray();
                INDArray outMask = t.getLabelsMaskArray();
                INDArray predicted = net.output(dataFeatures, false, inMask, outMask);

                evaluation.evalTimeSeries(dataLabels, predicted, outMask);
            }
            testData.reset();

            System.out.println(evaluation.stats());
        }
    }

    private static void getModelData() throws Exception {
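        // Ensure the working directory under java.io.tmpdir exists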
        File modelDir = new File(EXTRACT_DATA_PATH);
        if (!modelDir.exists()) {
            modelDir.mkdirs();
        }
        String archivePath = EXTRACT_DATA_PATH + "aclImdb_v1.tar.gz";
        File archiveName = new File(archivePath);
        String extractPath = EXTRACT_DATA_PATH + "aclImdb";
        File extractName = new File(extractPath);
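        // Download the archive only if it is missing; re-extract if the archive
        // exists but was never unpacked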
        if (!archiveName.exists()) {
            FileUtils.copyURLToFile(new URL(TRAINING_DATA_URL), archiveName);
            extractTar(archivePath, EXTRACT_DATA_PATH);
        } else if (!extractName.exists()) {
            extractTar(archivePath, EXTRACT_DATA_PATH);
        }
    }

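    // Buffer size, in bytes, used when copying extracted files to disk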
    private static final int BUFFER_SIZE = 4096;

    private static void extractTar(String dataIn, String dataOut) throws IOException {

        try (TarArchiveInputStream inStream = new TarArchiveInputStream(
                new GzipCompressorInputStream(new BufferedInputStream(new FileInputStream(dataIn))))) {
            TarArchiveEntry tarEntry;
            while ((tarEntry = inStream.getNextTarEntry()) != null) {
                if (tarEntry.isDirectory()) {
                    new File(dataOut + tarEntry.getName()).mkdirs();
                } else {
                    int count;
                    byte[] data = new byte[BUFFER_SIZE];

                    // try-with-resources guarantees each file's stream is flushed and closed
                    try (BufferedOutputStream outStream = new BufferedOutputStream(
                            new FileOutputStream(dataOut + tarEntry.getName()), BUFFER_SIZE)) {
                        while ((count = inStream.read(data, 0, BUFFER_SIZE)) != -1) {
                            outStream.write(data, 0, count);
                        }
                    }
                }
            }
        }
    }
}